import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
from torch.nn import functional as F


class Lenet5(nn.Module):
    """LeNet-5 classifier for the CIFAR-10 dataset.

    Expects inputs of shape [batch, 3, 32, 32] and produces raw
    (unnormalized) class logits of shape [batch, 10].
    """

    def __init__(self):
        super(Lenet5, self).__init__()
        # Feature extractor: [b, 3, 32, 32] -> [b, 16, 5, 5]
        # (32 -5+1-> 28 -pool-> 14 -5+1-> 10 -pool-> 5)
        self.conv_unit = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5, stride=1, padding=0),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
        )
        # Classifier head: flattened 16*5*5 features -> 10 class logits.
        self.fc_unit = nn.Sequential(
            nn.Linear(in_features=16 * 5 * 5, out_features=120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10),
        )
        # NOTE: unused by forward(); kept so the attribute stays available
        # to any external code. main() builds its own criterion.
        self.criteon = nn.CrossEntropyLoss()

    def forward(self, x):
        """
        :param x: input batch of shape [batch, 3, 32, 32]
        :return: logits of shape [batch, 10]
        """
        batch_size = x.size(0)
        x = self.conv_unit(x)
        # Flatten conv features before the fully-connected head.
        x = x.view(batch_size, 16 * 5 * 5)
        logits = self.fc_unit(x)
        # BUG FIX: the original computed logits but never returned them,
        # so model(x) evaluated to None and training crashed at the loss.
        return logits


def main():
    """Train LeNet-5 on CIFAR-10 for 20 epochs, reporting loss and test accuracy."""
    batch_size = 32

    # Single transform shared by train and test (the original duplicated it).
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ])

    cifar_train = datasets.CIFAR10('./data/cifar10', train=True,
                                   transform=transform, download=True)
    cifar_train = DataLoader(cifar_train, batch_size=batch_size, shuffle=True)

    cifar_test = datasets.CIFAR10('./data/cifar10', train=False,
                                  transform=transform, download=True)
    cifar_test = DataLoader(cifar_test, batch_size=batch_size, shuffle=False)

    model = Lenet5()
    print(model)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)  # fixed typo: "optiminzer"
    criteon = nn.CrossEntropyLoss()

    for epoch in range(20):
        # ---- training pass ----
        model.train()
        for batchidx, (x, label) in enumerate(cifar_train):
            logits = model(x)
            loss = criteon(logits, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Loss of the last batch of the epoch, matching the original report.
        print("loss:", epoch, loss.item())

        # ---- evaluation pass ----
        model.eval()
        with torch.no_grad():
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                logits = model(x)
                pred = logits.argmax(dim=1)
                total_correct += torch.eq(pred, label).float().sum().item()
                total_num += x.size(0)
            acc = total_correct / total_num
            print("acc", epoch, acc)


if __name__ == '__main__':
    main()